import torch
from torch import nn

# Allocate a 4-D tensor of shape (1, 10, 10, 1) with every element set to zero.
# NOTE(review): the layout looks like (batch, height, width, channel) for a
# single-channel image mask -- confirm against whatever consumes `img_mask`.
img_mask = torch.zeros(1, 10, 10, 1)

# Write the tensor's repr to stdout so its contents can be inspected.
print(img_mask)
